{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.9 GoogleNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "import torch.nn.functional as F\n",
    "import sys\n",
    "import d2lzh_pytorch as d2l\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Inception(nn.Module):\n",
    "    def __init__(self, in_c, c1, c2, c3, c4):\n",
    "        # c1 - c4为每条线路里的层的输出通道数\n",
    "        super(Inception, self).__init__()\n",
    "        self.p1_1 = nn.Conv2d(in_c, c1, kernel_size=1)\n",
    "        \n",
    "        self.p2_1 = nn.Conv2d(in_c, c2[0], kernel_size=1)\n",
    "        self.p2_2 = nn.Conv2d(c2[0], c2[1], kernel_size=3, padding=1)\n",
    "        \n",
    "        self.p3_1 = nn.Conv2d(in_c, c3[0], kernel_size=1)\n",
    "        self.p3_2 = nn.Conv2d(c3[0], c3[1], kernel_size=5, padding=2)\n",
    "        \n",
    "        self.p4_1 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)\n",
    "        self.p4_2 = nn.Conv2d(in_c, c4, kernel_size=1)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        p1 = F.relu(self.p1_1(x))\n",
    "        p2 = F.relu(self.p2_2(F.relu(self.p2_1(x))))\n",
    "        p3 = F.relu(self.p3_2(F.relu(self.p3_1(x))))\n",
    "        p4 = F.relu(self.p4_2(self.p4_1(x)))\n",
    "        return torch.cat((p1, p2, p3, p4), dim=1)  # 在通道维上连接输出"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),\n",
    "                  nn.ReLU(),\n",
    "                  nn.MaxPool2d(kernel_size=3, stride=2, padding=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "b2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1),\n",
    "                  nn.Conv2d(64, 192, kernel_size=3, padding=1),\n",
    "                  nn.MaxPool2d(kernel_size=3, stride=2, padding=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "b3 = nn.Sequential(Inception(192, 64, (96, 128), (16, 32), 32),\n",
    "                  Inception(256, 128, (128, 192), (32, 96), 64),\n",
    "                  nn.MaxPool2d(kernel_size=3, stride=2, padding=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "b4 = nn.Sequential(Inception(480, 192, (96, 208), (16, 48), 64),\n",
    "                  Inception(512, 160, (112, 224), (24, 64), 64),\n",
    "                  Inception(512, 128, (128, 256), (24, 64), 64),\n",
    "                  Inception(512, 112, (144, 288), (32, 64), 64),\n",
    "                  Inception(528, 256, (160, 320), (32, 128), 128),\n",
    "                  nn.MaxPool2d(kernel_size=3, stride=2, padding=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "b5 = nn.Sequential(Inception(832, 256, (160, 320), (32, 128), 128),\n",
    "                  Inception(832, 384, (192, 384), (48, 128), 128),\n",
    "                  d2l.GlobalAvgPool2d())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = nn.Sequential(b1, b2, b3, b4, b5, d2l.FlattenLayer(),\n",
    "                   nn.Linear(1024, 10))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "output shape:  torch.Size([1, 64, 24, 24])\n",
      "output shape:  torch.Size([1, 192, 12, 12])\n",
      "output shape:  torch.Size([1, 480, 6, 6])\n",
      "output shape:  torch.Size([1, 832, 3, 3])\n",
      "output shape:  torch.Size([1, 1024, 1, 1])\n",
      "output shape:  torch.Size([1, 1024])\n",
      "output shape:  torch.Size([1, 10])\n"
     ]
    }
   ],
   "source": [
    "X = torch.rand(1, 1, 96, 96)\n",
    "for blk in net.children():\n",
    "    X = blk(X)\n",
    "    print('output shape: ', X.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training on  cuda\n",
      "epoch 1, loss 2.3029, train acc 0.101, test acc 0.100, time 485.8 sec\n",
      "epoch 2, loss 2.3029, train acc 0.098, test acc 0.100, time 492.3 sec\n",
      "epoch 3, loss 2.3028, train acc 0.100, test acc 0.100, time 494.3 sec\n",
      "epoch 4, loss 2.3028, train acc 0.099, test acc 0.100, time 487.0 sec\n",
      "epoch 5, loss 2.3029, train acc 0.098, test acc 0.100, time 485.8 sec\n"
     ]
    }
   ],
   "source": [
    "batch_size = 16  # 128\n",
    "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)\n",
    "\n",
    "lr, num_epochs = 1e-3, 5\n",
    "optimizer = torch.optim. Adam(net.parameters(), lr=lr)\n",
    "d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, \n",
    "              device, num_epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
